import os

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json
import re
from tqdm import tqdm
# Default this script to GPUs 0 and 1, but respect a value the caller already
# exported (e.g. `CUDA_VISIBLE_DEVICES=2,3 python ...`) instead of silently
# overriding it. The variable only takes effect if it is set before CUDA is
# initialized in this process.
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0,1")

def load_model(model_path):
    """Load a causal language model and its tokenizer from ``model_path``.

    The model is loaded with ``device_map="auto"`` so its weights are placed
    across the available devices.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    lm = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto")
    return lm, AutoTokenizer.from_pretrained(model_path)
    
def construct_context(sample, examples):
    """Build the few-shot prompt used to classify ``sample``.

    The prompt consists of:
      * an instruction paragraph, emitted only when at least one
        demonstration is supplied;
      * one ``input:`` / ``Answer:`` pair per entry of ``examples``
        (each entry needs ``'input'`` and ``'output'`` keys);
      * the query ``sample['text']`` followed by an open ``Answer:`` cue
        for the model to complete.
    """
    demonstrations = [
        f"input: {demo['input']}\nAnswer: {demo['output']}\n\n" for demo in examples
    ]
    if demonstrations:
        instruction = (
            "Below is an input may be from pre-training corpus. "
            "if the input is seen in the pre-training step, the answer is Yes, "
            "otherwise, it is No. Please provide an answer.\n\n"
        )
    else:
        instruction = ""
    query = f"input: {sample['text']}\nAnswer:"
    return instruction + "".join(demonstrations) + query

def get_data_json(data_path):
    """Load and return the JSON document stored at ``data_path``.

    The file is decoded explicitly as UTF-8, the standard encoding for JSON.
    Relying on the platform's locale default could mis-decode non-ASCII text.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        return json.load(f)


def get_examples(dataset, num_per_class=2):
    """Pick few-shot demonstrations from ``dataset``.

    Scans ``dataset`` in order and keeps the first ``num_per_class`` samples
    with ``label == 0`` (rendered with output ``"No"``) and the first
    ``num_per_class`` samples with ``label == 1`` (output ``"Yes"``). Samples
    with any other label are skipped. Scanning stops as soon as both quotas
    are filled.

    Args:
        dataset: iterable of dicts with ``'text'`` and ``'label'`` keys.
        num_per_class: demonstrations to keep per class (default 2, the
            previously hard-coded value).

    Returns:
        A list of ``{"input": ..., "output": "Yes" | "No"}`` dicts, in dataset
        order. It may be shorter than ``2 * num_per_class`` if ``dataset``
        runs out of a class.
    """
    examples = []
    pos = 0
    neg = 0
    for data in dataset:
        # Check the quotas before touching the sample so that a quota of 0
        # returns immediately instead of scanning the data.
        if pos >= num_per_class and neg >= num_per_class:
            break
        if data['label'] == 0 and neg < num_per_class:
            examples.append({"input": f"{data['text']}", "output": "No"})
            neg += 1
        elif data['label'] == 1 and pos < num_per_class:
            examples.append({"input": f"{data['text']}", "output": "Yes"})
            pos += 1
    return examples

def generate(model, tokenizer, context, max_new_tokens=10):
    """Greedily continue ``context`` with ``model`` and return the answer window.

    Args:
        model: causal LM as returned by ``load_model``.
        tokenizer: the tokenizer matching ``model``.
        context: prompt text, normally built by ``construct_context`` so that
            it ends with an ``Answer:`` cue.
        max_new_tokens: generation budget (default 10, the previously
            hard-coded value).

    Returns:
        The last 20 characters of ``context`` followed by the decoded
        continuation. Keeping the tail of the prompt lets the caller's
        ``Answer:`` regex anchor on the cue the model is completing.
    """
    # Put the input tensors on the model's device and pass the attention mask
    # explicitly rather than relying on implicit device placement and a
    # defaulted mask.
    inputs = tokenizer(text=context, return_tensors="pt").to(model.device)
    output_tokens = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_new_tokens=max_new_tokens,
        num_return_sequences=1,
        do_sample=False,
    )
    # Decode only the newly generated tokens, which follow the prompt tokens
    # in the output. Slicing the decoded full text at len(context) - 20 would
    # assume decoding reproduces the prompt character-for-character. It would
    # also wrap around to the end of the output when the prompt is shorter
    # than 20 characters.
    prompt_len = inputs.input_ids.shape[1]
    continuation = tokenizer.decode(output_tokens[0][prompt_len:], skip_special_tokens=True)
    return context[-20:] + continuation

def load_data(data_path):
    """Load and return the JSON document stored at ``data_path``.

    Decodes the file explicitly as UTF-8, the standard encoding for JSON,
    instead of the platform's locale default, which could mis-decode
    non-ASCII text.
    """
    with open(data_path, "r", encoding="utf-8") as f:
        return json.load(f)

def icl(model, tokenizer, train_data, test_data):
    """Run few-shot classification over ``test_data`` and report metrics.

    Demonstrations are drawn from ``train_data`` via ``get_examples``. Each
    test sample is wrapped into a prompt with ``construct_context``, completed
    with ``generate``, and parsed into ``"Yes"``/``"No"``/``None`` by
    ``extract_answer``. The parsed answers are scored by ``evaluate`` against
    ``sample["label"]``.

    Prints accuracy, precision, recall, F1 and ``Count``. ``Count`` is the
    number of generations from which no answer could be parsed.

    Returns:
        The ``(accuracy, precision, recall, f1, count)`` tuple from
        ``evaluate``, so callers can use the metrics programmatically.
    """
    examples = get_examples(train_data)
    generations = []
    labels = []
    for sample in tqdm(test_data):
        context = construct_context(sample, examples)
        generations.append(generate(model, tokenizer, context))
        labels.append(sample["label"])
    answers = extract_answer(generations)
    accuracy, precision, recall, f1, count = evaluate(answers, labels)
    print(f"Accuracy: {accuracy}")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")
    print(f"F1: {f1}")
    print(f"Count: {count}")
    return accuracy, precision, recall, f1, count
    

def extract_answer(generated_text):
    """Parse the first ``Answer: Yes`` / ``Answer: No`` from each generation.

    Args:
        generated_text: iterable of generated strings.

    Returns:
        A list parallel to ``generated_text`` holding ``"Yes"``, ``"No"``, or
        ``None`` when no ``Answer:`` followed by Yes/No appears (matching is
        case-sensitive).
    """
    answer_re = re.compile(r"Answer:\s*(Yes|No)")

    def _first_answer(text):
        found = answer_re.search(text)
        if found is None:
            return None
        return found.group(1).strip()

    return [_first_answer(text) for text in generated_text]

def evaluate(results, labels):
    """Score parsed Yes/No answers against ground-truth labels.

    Args:
        results: per-sample answers, each ``"Yes"``, ``"No"`` or ``None``.
            ``None`` marks an unparsable generation; it is reported and
            skipped. Any non-``None`` value other than ``"No"`` counts as a
            positive (1) prediction.
        labels: ground-truth labels indexed in parallel with ``results``
            (1 = positive / "Yes", 0 = negative / "No").

    Returns:
        ``(accuracy, precision, recall, f1, count)``. ``count`` is the number
        of ``None`` entries. Accuracy is measured over the parsed answers
        only. Each metric is ``0`` when its denominator is zero.
    """
    def _safe_div(numerator, denominator):
        return 0 if denominator == 0 else numerator / denominator

    tp = tn = fp = fn = 0
    missing = 0
    for idx, answer in enumerate(results):
        if answer is None:
            print("Error: no answer found in the generated text.")
            missing += 1
            continue
        predicted = 0 if answer == 'No' else 1
        hit = predicted == labels[idx]
        if predicted == 1:
            if hit:
                tp += 1
            else:
                fp += 1
        elif hit:
            tn += 1
        else:
            fn += 1

    # tp + tn is exactly the number of correct parsed predictions.
    accuracy = _safe_div(tp + tn, len(results) - missing)
    precision = _safe_div(tp, tp + fp)
    recall = _safe_div(tp, tp + fn)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    return accuracy, precision, recall, f1, missing

if __name__ == "__main__":
    import argparse

    # Paths are command-line options; the defaults are the values this
    # script previously hard-coded, so a plain `python script.py` behaves
    # exactly as before.
    parser = argparse.ArgumentParser(
        description="Few-shot (in-context) pre-training membership classification with a causal LM."
    )
    parser.add_argument(
        "--model",
        default="/mnt/sharedata/ssd/users/zhanghx/models/llama-7b",
        help="Path or hub id of the causal LM to evaluate.",
    )
    parser.add_argument(
        "--train-data",
        default="benchmarks/fine_tuning/train_data.json",
        help="JSON file the few-shot demonstrations are drawn from.",
    )
    parser.add_argument(
        "--test-data",
        default="benchmarks/fine_tuning/test_data.json",
        help="JSON file of samples to classify.",
    )
    args = parser.parse_args()

    train_dataset = get_data_json(args.train_data)
    test_dataset = get_data_json(args.test_data)
    model, tokenizer = load_model(args.model)
    icl(model, tokenizer, train_dataset, test_dataset)